fadfcc
@@ -19,6 +19,7 @@
 package org.apache.hadoop.hive.ql.optimizer.stats.annotation;
 
 import java.lang.reflect.Field;
+import java.util.Arrays;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
@@ -1543,8 +1544,16 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
 
         // update join statistics
         stats.setColumnStats(outColStats);
-        long joinRowCount = inferredRowCount !=-1 ? inferredRowCount : computeNewRowCount(rowCounts, denom, jop);
-        updateColStats(conf, stats, joinRowCount, jop, rowCountParents);
+
+        // The interim row count ignores the join type (it assumes an inner join); it is computed
+        // separately because it is needed later to estimate the number of nulls.
+        long interimRowCount = inferredRowCount !=-1 ? inferredRowCount
+            :computeRowCountAssumingInnerJoin(rowCounts, denom, jop);
+        // final row computation will consider join type
+        long joinRowCount = inferredRowCount !=-1 ? inferredRowCount
+            :computeFinalRowCount(rowCounts, interimRowCount, jop);
+
+        updateColStats(conf, stats, interimRowCount, joinRowCount, jop, rowCountParents);
 
         // evaluate filter expression and update statistics
         if (joinRowCount != -1 && jop.getConf().getNoOuterJoin() &&
@@ -1775,7 +1784,9 @@
private long getCardinality(List<Operator<? extends OperatorDesc>> ops, Integer
         newNumRows = newrows;
       } else {
         // there is more than one FK
-        newNumRows = this.computeNewRowCount(rowCounts, getDenominator(distinctVals), jop);
+        newNumRows = this.computeRowCountAssumingInnerJoin(rowCounts,
+            getDenominator(distinctVals), jop);
+        newNumRows = this.computeFinalRowCount(rowCounts, newNumRows, jop);
       }
       return newNumRows;
     }
@@ -1895,7 +1906,85 @@
private float getSelectivityComplexTree(Operator<? extends OperatorDesc> op) {
       return result;
     }
 
-    private void updateColStats(HiveConf conf, Statistics stats, long newNumRows,
+    // Returns true when some entry of joinKeys is a plain column reference
+    // (ExprNodeColumnDesc) whose column equals columnName.
+    private boolean isJoinKey(final String columnName, final ExprNodeDesc[][] joinKeys) {
+      for (ExprNodeDesc[] keyRow : joinKeys) {
+        for (ExprNodeDesc key : keyRow) {
+          // Non-column expressions never count as a match.
+          boolean isColumnRef = key instanceof ExprNodeColumnDesc;
+          if (isColumnRef && ((ExprNodeColumnDesc) key).getColumn().equals(columnName)) {
+            return true;
+          }
+        }
+      }
+      return false;
+    }
+
+    /**
+     * Re-estimates the null count of one join output column and stores it in colStats.
+     * Only joins with a single join condition are handled; otherwise colStats is untouched.
+     * The count is capped at newNumRows; outer joins additionally account for the
+     * newNumRows - interimNumRows rows by which the final count exceeds the interim one.
+     *
+     * @param colStats statistics of the column, updated in place
+     * @param interimNumRows join row count computed without considering the join type
+     * @param newNumRows final join row count
+     * @param pos position compared against the left/right input of the join condition
+     * @param jop join operator supplying the join condition and the join keys
+     */
+    private void updateNumNulls(ColStatistics colStats, long interimNumRows, long newNumRows,
+        long pos, CommonJoinOperator<? extends JoinDesc> jop) {
+      if (jop.getConf().getConds().length != 1) {
+        // TODO: handle multi joins
+        return;
+      }
+      // without extra nulls the old count is kept, capped at the new row count
+      long newNumNulls = Math.min(newNumRows, colStats.getNumNulls());
+
+      JoinCondDesc joinCond = jop.getConf().getConds()[0];
+      switch (joinCond.getType()) {
+      case JoinDesc.LEFT_OUTER_JOIN:
+        // only columns coming from the right input gain nulls
+        if (pos == joinCond.getRight() && interimNumRows != newNumRows) {
+          // computeFinalRowCount never lowers the count for this join type
+          assert (newNumRows > interimNumRows);
+          newNumNulls = numNullsWithUnmatchedRows(colStats, interimNumRows, newNumRows, jop);
+        }
+        break;
+      case JoinDesc.RIGHT_OUTER_JOIN:
+        // only columns coming from the left input gain nulls
+        if (pos == joinCond.getLeft() && interimNumRows != newNumRows) {
+          // computeFinalRowCount never lowers the count for this join type
+          assert (newNumRows > interimNumRows);
+          newNumNulls = numNullsWithUnmatchedRows(colStats, interimNumRows, newNumRows, jop);
+        }
+        break;
+      case JoinDesc.FULL_OUTER_JOIN:
+        newNumNulls = numNullsWithUnmatchedRows(colStats, interimNumRows, newNumRows, jop);
+        break;
+      case JoinDesc.INNER_JOIN:
+      case JoinDesc.UNIQUE_JOIN:
+      case JoinDesc.LEFT_SEMI_JOIN:
+        break;
+      }
+      colStats.setNumNulls(newNumNulls);
+    }
+
+    // Null count once the newNumRows - interimNumRows extra rows are accounted for: a join key
+    // gets exactly that many nulls, any other column gets them on top of its current count.
+    private long numNullsWithUnmatchedRows(ColStatistics colStats, long interimNumRows,
+        long newNumRows, CommonJoinOperator<? extends JoinDesc> jop) {
+      long unmatchedRows = newNumRows - interimNumRows;
+      if (isJoinKey(colStats.getColumnName(), jop.getConf().getJoinKeys())) {
+        return Math.min(newNumRows, unmatchedRows);
+      }
+      // safeAdd guards the sum against long overflow; a plain + could wrap to a negative count
+      return Math.min(newNumRows, StatsUtils.safeAdd(colStats.getNumNulls(), unmatchedRows));
+    }
+
+    private void updateColStats(HiveConf conf, Statistics stats, long interimNumRows,
+        long newNumRows,
         CommonJoinOperator<? extends JoinDesc> jop,
         Map<Integer, Long> rowCountParents) {
 
@@ -1934,10 +2023,9 @@
private void updateColStats(HiveConf conf, Statistics stats, long newNumRows,
         if (ratio <= 1.0) {
           newDV = (long) Math.ceil(ratio * oldDV);
         }
-        // Assumes inner join
-        // TODO: HIVE-5579 will handle different join types
-        cs.setNumNulls(0);
+
         cs.setCountDistint(newDV);
+        updateNumNulls(cs, interimNumRows, newNumRows, pos, jop);
       }
       stats.setColumnStats(colStats);
       long newDataSize = StatsUtils
@@ -1956,7 +2044,41 @@
private void updateColStats(HiveConf conf, Statistics stats, long newNumRows,
       stats.setDataSize(StatsUtils.getMaxIfOverflow(newDataSize));
     }
 
-    private long computeNewRowCount(List<Long> rowCountParents, long denom, CommonJoinOperator<? extends JoinDesc> join) {
+    // Adjusts interimRowCount, a join row count estimated without regard to the join type,
+    // for the actual join type. rowCountParents is indexed by the left/right position of the
+    // join condition. Joins with more than one condition keep interimRowCount unchanged.
+    private long computeFinalRowCount(List<Long> rowCountParents, long interimRowCount,
+        CommonJoinOperator<? extends JoinDesc> join) {
+      JoinCondDesc[] conds = join.getConf().getConds();
+      if (conds.length != 1) {
+        return interimRowCount;
+      }
+      JoinCondDesc cond = conds[0];
+      switch (cond.getType()) {
+      case JoinDesc.INNER_JOIN:
+        // nothing to adjust for a plain inner join
+        return interimRowCount;
+      case JoinDesc.LEFT_OUTER_JOIN:
+        // every left row is in the result, so the left row count is a lower bound
+        return Math.max(rowCountParents.get(cond.getLeft()), interimRowCount);
+      case JoinDesc.RIGHT_OUTER_JOIN:
+        // every right row is in the result, so the right row count is a lower bound
+        return Math.max(rowCountParents.get(cond.getRight()), interimRowCount);
+      case JoinDesc.FULL_OUTER_JOIN:
+        // rows of both sides are in the result, so their sum is a lower bound
+        long bothSides = StatsUtils.safeAdd(rowCountParents.get(cond.getRight()),
+            rowCountParents.get(cond.getLeft()));
+        return Math.max(bothSides, interimRowCount);
+      case JoinDesc.LEFT_SEMI_JOIN:
+        // the left row count is an upper bound
+        return Math.min(rowCountParents.get(cond.getLeft()), interimRowCount);
+      default:
+        LOG.debug("Unhandled join type in stats estimation: " + cond.getType());
+        return interimRowCount;
+      }
+    }
+    private long computeRowCountAssumingInnerJoin(List<Long> rowCountParents, long denom,
+        CommonJoinOperator<? extends JoinDesc> join) {
       double factor = 0.0d;
       long result = 1;
       long max = rowCountParents.get(0);
@@ -1982,33 +2104,6 @@
private long computeNewRowCount(List<Long> rowCountParents, long denom, CommonJo
 
       result = (long) (result * factor);
 
-      if (join.getConf().getConds().length == 1) {
-        JoinCondDesc joinCond = join.getConf().getConds()[0];
-        switch (joinCond.getType()) {
-          case JoinDesc.INNER_JOIN:
-            // only dealing with special join types here.
-            break;
-          case JoinDesc.LEFT_OUTER_JOIN :
-            // all rows from left side will be present in resultset
-            result = Math.max(rowCountParents.get(joinCond.getLeft()),result);
-            break;
-          case JoinDesc.RIGHT_OUTER_JOIN :
-            // all rows from right side will be present in resultset
-            result = Math.max(rowCountParents.get(joinCond.getRight()),result);
-            break;
-          case JoinDesc.FULL_OUTER_JOIN :
-            // all rows from both side will be present in resultset
-            result = Math.max(StatsUtils.safeAdd(rowCountParents.get(joinCond.getRight()), rowCountParents.get(joinCond.getLeft())),result);
-            break;
-          case JoinDesc.LEFT_SEMI_JOIN :
-            // max # of rows = rows from left side
-            result = Math.min(rowCountParents.get(joinCond.getLeft()),result);
-            break;
-          default:
-            LOG.debug("Unhandled join type in stats estimation: " + joinCond.getType());
-            break;
-        }
-      }
       return result;
     }
 
